{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bidirectional LSTM in Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we use a *bidirectional* LSTM to classify IMDB movie reviews by their sentiment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "from keras.datasets import imdb\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Embedding, SpatialDropout1D, LSTM\n",
    "from keras.layers.wrappers import Bidirectional # new! \n",
    "from keras.callbacks import ModelCheckpoint\n",
    "import os\n",
    "from sklearn.metrics import roc_auc_score \n",
    "import matplotlib.pyplot as plt \n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# output directory name:\n",
    "output_dir = 'model_output/biLSTM'\n",
    "\n",
    "# training:\n",
    "epochs = 6\n",
    "batch_size = 128\n",
    "\n",
    "# vector-space embedding: \n",
    "n_dim = 64 \n",
    "n_unique_words = 10000 \n",
    "max_review_length = 200 # doubled!\n",
    "pad_type = trunc_type = 'pre'\n",
    "drop_embed = 0.2 \n",
    "\n",
    "# LSTM layer architecture:\n",
    "n_lstm = 256 \n",
    "drop_lstm = 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "(x_train, y_train), (x_valid, y_valid) = imdb.load_data(num_words=n_unique_words) # removed n_words_to_skip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Preprocess data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x_train = pad_sequences(x_train, maxlen=max_review_length, padding=pad_type, truncating=trunc_type, value=0)\n",
    "x_valid = pad_sequences(x_valid, maxlen=max_review_length, padding=pad_type, truncating=trunc_type, value=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### Design neural network architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "model.add(Embedding(n_unique_words, n_dim, input_length=max_review_length)) \n",
    "model.add(SpatialDropout1D(drop_embed))\n",
    "model.add(Bidirectional(LSTM(n_lstm, dropout=drop_lstm)))\n",
    "model.add(Dense(1, activation='sigmoid'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, 200, 64)           640000    \n",
      "_________________________________________________________________\n",
      "spatial_dropout1d_1 (Spatial (None, 200, 64)           0         \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, 512)               657408    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 1)                 513       \n",
      "=================================================================\n",
      "Total params: 1,297,921\n",
      "Trainable params: 1,297,921\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# LSTM layer parameters double due to both reading directions\n",
    "model.summary() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Configure model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Create the checkpoint directory up front; exist_ok=True makes this a no-op\n",
    "# when the directory already exists and avoids the check-then-create race.\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Save a checkpoint file after each epoch; Keras fills the {epoch:02d}\n",
    "# placeholder in the filename with the epoch number.\n",
    "modelcheckpoint = ModelCheckpoint(filepath=output_dir+\"/weights.{epoch:02d}.hdf5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Train!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 25000 samples, validate on 25000 samples\n",
      "Epoch 1/6\n",
      "25000/25000 [==============================] - 118s - loss: 0.5368 - acc: 0.7193 - val_loss: 0.3578 - val_acc: 0.8480\n",
      "Epoch 2/6\n",
      "25000/25000 [==============================] - 112s - loss: 0.2881 - acc: 0.8835 - val_loss: 0.3141 - val_acc: 0.8719\n",
      "Epoch 3/6\n",
      "25000/25000 [==============================] - 114s - loss: 0.2198 - acc: 0.9164 - val_loss: 0.3167 - val_acc: 0.8651\n",
      "Epoch 4/6\n",
      "25000/25000 [==============================] - 114s - loss: 0.1772 - acc: 0.9344 - val_loss: 0.3469 - val_acc: 0.8598\n",
      "Epoch 5/6\n",
      "25000/25000 [==============================] - 114s - loss: 0.1492 - acc: 0.9445 - val_loss: 0.3802 - val_acc: 0.8676\n",
      "Epoch 6/6\n",
      "25000/25000 [==============================] - 115s - loss: 0.1260 - acc: 0.9536 - val_loss: 0.4241 - val_acc: 0.8620\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f17a912ea20>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# - we see 87.2% validation accuracy in epoch 2\n",
    "# - with this toy dataset, the complex interplay of words over long sentence segments won't be learned much\n",
    "# - our CNN, by contrast, picks up location-invariant segments of two to four words that predict review sentiment\n",
    "# - these are simpler and so easier to learn from the data\n",
    "# - the CNN therefore outperforms the bidirectional LSTM on the IMDB data set\n",
    "model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_valid, y_valid), callbacks=[modelcheckpoint])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.load_weights(output_dir+\"/weights.01.hdf5\") # zero-indexed, so 01 is the checkpoint saved after epoch 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25000/25000 [==============================] - 92s    \n"
     ]
    }
   ],
   "source": [
    "y_hat = model.predict_proba(x_valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAERpJREFUeJzt3X+sX3V9x/HnSyr+VooU51q2i7E6kWQRbxBn4pw1/NJY\n/hBTM0clzZo45pwzm7gtYUFJ0P1ASfyxTjrBOIExMxrFkQ4hbotFL+JUYIQOGHQwudpadcQf1ff+\n+H5wVz63vd/e773329s+H0nzPedzPuec94dbeN3zOed7SFUhSdJMTxh3AZKkQ4/hIEnqGA6SpI7h\nIEnqGA6SpI7hIEnqGA6SpI7hIEnqzBkOSbYmeSTJN2a0HZtke5J72ufK1p4klyfZmeRrSU6Zsc/G\n1v+eJBtntL80ydfbPpcnyUIPUpJ0cDLXN6STvBL4PnBVVZ3c2t4P7K6qS5NcCKysqnclORt4G3A2\n8DLgg1X1siTHAlPAJFDAbcBLq2pPki8Bbwd2ADcAl1fV5+Yq/LjjjquJiYl5DVpaNN+9e/D5zBeO\ntw5pFrfddtu3qmrVMH1XzNWhqr6QZOJxzeuBV7XlK4FbgHe19qtqkDg7khyT5Lmt7/aq2g2QZDtw\nZpJbgGdW1Rdb+1XAOcCc4TAxMcHU1NRc3aSl9c+vGny+5pZxViHNKsl/Ddt3vvccnlNVDwO0z+Nb\n+2rgwRn9drW2A7XvmqVdkjRGC31Derb7BTWP9tkPnmxOMpVkanp6ep4lSpLmMt9w+GabLqJ9PtLa\ndwEnzOi3BnhojvY1s7TPqqq2VNVkVU2uWjXUtJkkaR7mGw7bgMeeONoIXD+j/bz21NJpwN427XQj\ncHqSle3JptOBG9u27yU5rT2ldN6MY0mSxmTOG9JJPsXghvJxSXYBFwGXAtcm2QQ8AJzbut/A4Eml\nncCjwPkAVbU7yXuAL7d+Fz92cxp4K/Bx4CkMbkTPeTNakrS4hnla6U372bRulr4FXLCf42wFts7S\nPgWcPFcdkqSl4zekJUkdw0GS1DEcJEmdOe85HI4mLvzsWM57/6WvHct5JelgeeUgSeoYDpKkjuEg\nSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoY\nDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKk\njuEgSeoYDpKkzkjhkOQdSe5I8o0kn0ry5CQnJrk1yT1JrklydOv7pLa+s22fmHGcd7f2u5OcMdqQ\nJEmjmnc4JFkN/B4wWVUnA0cBG4D3AZdV1VpgD7Cp7bIJ2FNVzwcua/1IclLb78XAmcCHkxw137ok\nSaMbdVppBfCUJCuApwIPA68GrmvbrwTOacvr2zpt+7okae1XV9UPq+o+YCdw6oh1SZJGMO9wqKr/\nBv4CeIBBKOwFbgO+U1X7WrddwOq2vBp4sO27r/V/9sz2WfaRJI3BKNNKKxn81n8i8IvA04CzZula\nj+2yn237a5/tnJuTTCWZmp6ePviiJUlDGWVa6TXAfVU1XVU/Bj4N/BpwTJtmAlgDPNSWdwEnALTt\nzwJ2z2yfZZ+fU1VbqmqyqiZXrVo1QumSpAMZJRweAE5L8tR272AdcCdwM/CG1mcjcH1b3tbWads/\nX1XV2je0p5lOBNYCXxqhLknSiFbM3WV2VXVrkuuArwD7gNuBLcBngauTvLe1XdF2uQL4RJKdDK4Y\nNrTj3JHkWgbBsg+4oKp+Mt+6JEmjm3c4AFTVRcBFj2u+l1meNqqqHwDn7uc4lwCXjFKLJGnh+A1p\nSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLH\ncJAkdUZ6ZbckHakmLvzsWM57/6WvXZ
LzeOUgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoY\nDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeoYDpKkjuEgSeqMFA5J\njklyXZL/SHJXkpcnOTbJ9iT3tM+VrW+SXJ5kZ5KvJTllxnE2tv73JNk46qAkSaMZ9crhg8A/VdWv\nAL8K3AVcCNxUVWuBm9o6wFnA2vZnM/ARgCTHAhcBLwNOBS56LFAkSeMx73BI8kzglcAVAFX1o6r6\nDrAeuLJ1uxI4py2vB66qgR3AMUmeC5wBbK+q3VW1B9gOnDnfuiRJoxvlyuF5wDTwt0luT/KxJE8D\nnlNVDwO0z+Nb/9XAgzP239Xa9tcuSRqTUcJhBXAK8JGqegnwv/z/FNJsMktbHaC9P0CyOclUkqnp\n6emDrVeSNKRRwmEXsKuqbm3r1zEIi2+26SLa5yMz+p8wY/81wEMHaO9U1ZaqmqyqyVWrVo1QuiTp\nQOYdDlX1P8CDSV7YmtYBdwLbgMeeONoIXN+WtwHntaeWTgP2tmmnG4HTk6xsN6JPb22SpDFZMeL+\nbwM+meRo4F7gfAaBc22STcADwLmt7w3A2cBO4NHWl6raneQ9wJdbv4uraveIdUmSRjBSOFTVV4HJ\nWTatm6VvARfs5zhbga2j1CJJWjh+Q1qS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkd\nw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS\n1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEcJEkdw0GS1DEc\nJEmdkcMhyVFJbk/ymbZ+YpJbk9yT5JokR7f2J7X1nW37xIxjvLu1353kjFFrkiSNZiGuHN4O3DVj\n/X3AZVW1FtgDbGrtm4A9VfV84LLWjyQnARuAFwNnAh9OctQC1CVJmqeRwiHJGuC1wMfaeoBXA9e1\nLlcC57Tl9W2dtn1d678euLqqflhV9wE7gVNHqUuSNJpRrxw+APwR8NO2/mzgO1W1r63vAla35dXA\ngwBt+97W/2fts+wjSRqDeYdDktcBj1TVbTObZ+lac2w70D6PP+fmJFNJpqanpw+qXknS8Ea5cngF\n8Pok9wNXM5hO+gBwTJIVrc8a4KG2vAs4AaBtfxawe2b7LPv8nKraUlWTVTW5atWqEUqXJB3IvMOh\nqt5dVWuqaoLBDeXPV9VvAjcDb2jdNgLXt+VtbZ22/fNVVa19Q3ua6URgLfCl+dYlSRrdirm7HLR3\nAVcneS9wO3BFa78C+ESSnQyuGDYAVNUdSa4F7gT2ARdU1U8WoS5J0pAWJByq6hbglrZ8L7M8bVRV\nPwDO3c/+lwCXLEQtkqTR+Q1pSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLH\ncJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAk\ndQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdQwHSVLHcJAkdVbM\nd8ckJwBXAb8A/BTYUlUfTHIscA0wAdwPvLGq9iQJ8EHgbOBR4C1V9ZV2rI3An7ZDv7eqrpxvXZKO\nHBMXfnbcJRy2Rrly2Ae8s6peBJwGXJDkJOBC4KaqWgvc1NYBzgLWtj+bgY8AtDC5CHgZcCpwUZKV\nI9QlSRrRvMOhqh5+7Df/qvoecBewGlgPPPab/5XAOW15PXBVDewAjknyXOAMYHtV7a6qPcB24Mz5\n1iVJGt2C3HNIMgG8BLgVeE5VPQyDAAGOb91WAw/O2G1Xa9tfuyRpTEYOhyRPB/4B+P2q+u6Bus7S\nVg
don+1cm5NMJZmanp4++GIlSUMZKRySPJFBMHyyqj7dmr/Zpoton4+09l3ACTN2XwM8dID2TlVt\nqarJqppctWrVKKVLkg5g3uHQnj66Arirqv5qxqZtwMa2vBG4fkb7eRk4Ddjbpp1uBE5PsrLdiD69\ntUmSxmTej7ICrwB+C/h6kq+2tj8GLgWuTbIJeAA4t227gcFjrDsZPMp6PkBV7U7yHuDLrd/FVbV7\nhLokSSOadzhU1b8y+/0CgHWz9C/ggv0cayuwdb61SJIWlt+QliR1DAdJUsdwkCR1DAdJUsdwkCR1\nDAdJUmeU7znoII3z9cL3X/rasZ1b0vLjlYMkqWM4SJI6hoMkqWM4SJI6hoMkqWM4SJI6hoMkqeP3\nHCSNbJzf4dHi8MpBktQxHCRJHcNBktQxHCRJHW9IHyHGdcPQF/5Jy5NXDpKkjlcO0mHCx0m1kLxy\nkCR1DAdJUsdpJS2qI/H/frfj3m+zwSkeLXOGgw5b4wimq5/37SU/p7QYnFaSJHUMB0lSx3CQJHUM\nB0lSx3CQJHUMB0lSx3CQJHUMB0lS55AJhyRnJrk7yc4kF467Hkk6kh0S4ZDkKOBDwFnAScCbkpw0\n3qok6ch1SIQDcCqws6ruraofAVcD68dckyQdsQ6VcFgNPDhjfVdrkySNwaHy4r3M0lZdp2QzsLmt\nfj/J3fM833HAt+a573LlmJfAy3+29LqlPO1j/BkfAfK+kcb8y8N2PFTCYRdwwoz1NcBDj+9UVVuA\nLaOeLMlUVU2OepzlxDEf/o608YJjXkyHyrTSl4G1SU5McjSwAdg25pok6Yh1SFw5VNW+JL8L3Agc\nBWytqjvGXJYkHbEOiXAAqKobgBuW6HQjT00tQ4758HekjRcc86JJVXffV5J0hDtU7jlIkg4hh204\nzPU6jiRPSnJN235rkomlr3JhDTHmP0hyZ5KvJbkpydCPtR2qhn3tSpI3JKkky/7JlmHGnOSN7Wd9\nR5K/W+oaF9oQf7d/KcnNSW5vf7/PHkedCynJ1iSPJPnGfrYnyeXtn8nXkpyyoAVU1WH3h8FN7f8E\nngccDfw7cNLj+vwO8NG2vAG4Ztx1L8GYfwN4alt+65Ew5tbvGcAXgB3A5LjrXoKf81rgdmBlWz9+\n3HUvwZi3AG9tyycB94+77gUY9yuBU4Bv7Gf72cDnGHxP7DTg1oU8/+F65TDM6zjWA1e25euAdUlm\n+zLecjHnmKvq5qp6tK3uYPB9kuVs2NeuvAd4P/CDpSxukQwz5t8GPlRVewCq6pElrnGhDTPmAp7Z\nlp/FLN+TWm6q6gvA7gN0WQ9cVQM7gGOSPHehzn+4hsMwr+P4WZ+q2gfsBZ69JNUtjoN9BckmBr91\nLGdzjjnJS4ATquozS1nYIhrm5/wC4AVJ/i3JjiRnLll1i2OYMf8Z8OYkuxg89fi2pSltrBb1tUOH\nzKOsC2yY13EM9cqOZWTo8SR5MzAJ/PqiVrT4DjjmJE8ALgPeslQFLYFhfs4rGEwtvYrB1eG/JDm5\nqr6zyLUtlmHG/Cbg41X1l0leDnyijfmni1/e2Czqf8MO1yuHYV7H8bM+SVYwuBQ90CXcoW6oV5Ak\neQ3wJ8Drq+qHS1TbYplrzM8ATgZuSXI/g3nZbcv8pvSwf7evr6ofV9V9wN0MwmK5GmbMm4BrAarq\ni8CTGbx36XA21L/z83W4hsMwr+PYBmxsy28APl/tLs8yNeeY2xTLXzMIhuU+Dw1zjLmq9lbVcVU1\nUVUTDO6zvL6qpsZT7oIY5u/2PzJ4+IAkxzGYZrp3SatcWMOM+QFgHUCSFzEIh+klrXLpbQPOa08t\nnQbsraqHF+rgh+W0Uu3ndRxJLgamqmobcAWDS8+dDK4YNoyv4tENOeY/B54O/H279/5AVb1+bEWP\naMgxH1aGHPONwOlJ7gR+AvxhVX17fFWPZsgxvxP4myTvYDC18pZl
/sseST7FYGrwuHYv5SLgiQBV\n9VEG91bOBnYCjwLnL+j5l/k/P0nSIjhcp5UkSSMwHCRJHcNBktQxHCRJHcNBktQxHCRJHcNBktQx\nHCRJnf8DaFnnSTEPqCgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f17a90b0160>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist(y_hat)\n",
    "_ = plt.axvline(x=0.5, color='orange')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'94.39'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"{:0.2f}\".format(roc_auc_score(y_valid, y_hat)*100.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
